################################
## Name: 02_cutoff_sbert_embeddings.R
## Purpose: This script computes between-article cosine similarity
## on sbert embeddings, sets a similarity cutoff, and checks recall.
## Pairs at or above the cutoff are passed on to the sbert cross encoder.
## Data In: 
## 1) bioweapons case study articles with sbert embeddings
## data/bioweapons_casestudy_5_20_2024_embeddings_sbert.json
## 2) recall articles
## data/recall_master_updated_11_24_2024_public.rds
## Data Out:
## list of potential matches to get cross encoder
## scores and gpt-4o annotations for, includes all unique pairs of 
## articles placed into same cluster 


library(tidyverse)
library(jsonlite)
library(igraph) 

## Set working directory to replication file

## Read the article records and the recall set of known matching pairs.
## NOTE(review): the header comment lists the article input as
## "..._embeddings_sbert.json" while the path read here ends in
## "_sbert_embeddings.json" -- confirm which file name is correct.
articles <- jsonlite::stream_in(
  file("data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json")
)
recall <- readRDS("data/recall_master_updated_11_24_2024_public.rds")

## Stack the per-article embedding vectors into matrices, one row per
## list element. Row i is later used to look up articles[i, ], which
## assumes every article has an embedding for each model.
embeddings_roberta <- do.call(rbind, articles$embeddings_roberta)
embeddings_base <- do.call(rbind, articles$embeddings_base)

## Cosine similarity between every row of `small` and every row of `big`.
## Both arguments are numeric matrices holding one vector per row and
## must have the same number of columns.
## Returns an nrow(small) x nrow(big) matrix whose [i, j] entry is the
## cosine similarity of small[i, ] and big[j, ].
## A row of all zeros has norm 0, so its entries come out as NaN.
cos_sim <- function(small, big) {
  ## dot product of every pair of rows
  dot_products <- small %*% t(big)
  ## outer product of the squared row norms; its square root is the
  ## product of the two row norms for each pair
  squared_norm_products <- tcrossprod(rowSums(small^2), rowSums(big^2))
  dot_products / sqrt(squared_norm_products)
}

#######################################
#### Calculate Pairwise sim ########
#######################################


## The same three steps are applied to both embedding models:
##   a) pairwise cosine similarity between all articles
##   b) turn the similarity matrix into a weighted graph
##   c) flatten the graph into an edge list of unique article pairs
## Steps b) and c) are shared helpers so that the two models cannot
## drift apart.

## Build an undirected, weighted graph from a square similarity matrix.
## Each article is a node and the similarity between each pair of
## articles is the weight of the tie between them. Only the upper
## triangle is used and the diagonal is dropped, so every unordered
## pair appears at most once and there are no self ties.
## NOTE(review): igraph reads a weighted adjacency matrix entry of
## exactly 0 as "no edge", so a pair with similarity exactly 0 would
## not appear in the edge list -- confirm this is acceptable.
similarity_graph <- function(sim_matrix) {
  igraph::graph_from_adjacency_matrix(sim_matrix,
                                      mode = "upper",
                                      weighted = TRUE,
                                      diag = FALSE)
}

## Turn a similarity graph into a data frame with one row per tie.
##   net      - graph returned by similarity_graph()
##   articles - article data frame; row i must describe node i
## Returns a data frame with the node indices (V1, V2), the cosine
## similarity (weight) and the id, summary, date and source of the
## two articles in the pair (ego_* for V1, alter_* for V2).
pair_edgelist <- function(net, articles) {
  ## The lookups below index `articles` by node number. If an embedding
  ## was missing, do.call(rbind, ...) would have dropped it silently and
  ## every later row would be attached to the wrong article.
  stopifnot(igraph::vcount(net) == nrow(articles))

  edges <- as.data.frame(igraph::as_edgelist(net))
  edges$weight <- igraph::E(net)$weight

  ## merge in information about each article in each pair
  edges$ego_id <- articles$article_id[edges$V1]
  edges$alter_id <- articles$article_id[edges$V2]

  edges$ego_summary <- articles$summary[edges$V1]
  edges$alter_summary <- articles$summary[edges$V2]

  edges$ego_date <- articles$final_date[edges$V1]
  edges$alter_date <- articles$final_date[edges$V2]

  edges$ego_source <- articles$source_name[edges$V1]
  edges$alter_source <- articles$source_name[edges$V2]

  edges
}

## 1) roberta
## cosine similarity
m_roberta <- cos_sim(embeddings_roberta,
                     embeddings_roberta)

## decompose into graph, then into an edge list of unique
## (unordered) ties weighted by cosine similarity
net_roberta <- similarity_graph(m_roberta)
edgelist_roberta <- pair_edgelist(net_roberta, articles)

## 2) base model - repeat process for
## MPNet embeddings
m_base <- cos_sim(embeddings_base,
                  embeddings_base)

net_base <- similarity_graph(m_base)
edgelist_base <- pair_edgelist(net_base, articles)

########################
### Checking Recall ####
########################

## We merge in data on whether
## the pair in the edgelists was included
## in the recall dataset. 

## Before doing so we need to create 
## unique match IDs that will allow 
## us to merge the edge lists and the 
## recall dataset, regardless of the order
## of the pairs (whether article is ego 
## or alter), we do so by creating an ID
## where the articles in each pair are listed in
## alphabetical order

## create match ids
## Build an order-independent pair id from two vectors of article ids.
##   id_a, id_b - vectors of the same length; element i of each is one
##                member of pair i (converted to character)
## Returns a character vector "<first>_<second>" where the two ids of
## each pair are listed in alphabetical order, so the id is the same
## whichever article is the ego and whichever is the alter.
## The ids are handled as whole columns. The previous row-by-row
## apply() converted the full data frame to a character matrix first,
## which took minutes per table and formats numeric columns with padding.
make_match_id <- function(id_a, id_b) {
  id_a <- as.character(id_a)
  id_b <- as.character(id_b)
  stopifnot(length(id_a) == length(id_b))
  ## paste0() would return "_" for zero-length input
  if (length(id_a) == 0L) {
    return(character(0))
  }
  ## Swap when the first id sorts after the second. A missing id is
  ## placed last, which is what order() does by default.
  swap <- (is.na(id_a) & !is.na(id_b)) |
    (!is.na(id_a) & !is.na(id_b) & id_a > id_b)
  paste0(ifelse(swap, id_b, id_a), "_", ifelse(swap, id_a, id_b))
}

## create match ids
recall$match_id <- make_match_id(recall$article_id,
                                 recall$focal_article)

edgelist_roberta$match_id <- make_match_id(edgelist_roberta$ego_id,
                                           edgelist_roberta$alter_id)

edgelist_base$match_id <- make_match_id(edgelist_base$ego_id,
                                        edgelist_base$alter_id)

## Sanity checks on the match ids. Each line prints a count.

## "not in" operator
`%ni%` <- Negate(`%in%`)

## If the match ids were created correctly, every recall pair
## is present in both edge lists
sum(!(recall$match_id %in% edgelist_roberta$match_id)) ## 0
sum(!(recall$match_id %in% edgelist_base$match_id)) ## 0
## correct - none missing

## no pair appears more than once
sum(duplicated(edgelist_base[["match_id"]])) ## 0 duplicated
sum(duplicated(edgelist_roberta[["match_id"]])) ## 0

## no article is paired with itself
with(edgelist_base, sum(ego_id == alter_id))
with(edgelist_roberta, sum(ego_id == alter_id))

## Flag each pair by whether it is in the recall dataset and label
## each edge list with the embedding model it came from.
edgelist_roberta[["recall"]] <- dplyr::if_else(
  edgelist_roberta$match_id %in% recall$match_id,
  "Recall Set",
  "Not\nRecall Set"
)
edgelist_base[["recall"]] <- dplyr::if_else(
  edgelist_base$match_id %in% recall$match_id,
  "Recall Set",
  "Not\nRecall Set"
)

edgelist_base[["type"]] <- "MPNet Base"
edgelist_roberta[["type"]] <- "STS Roberta Large"

## Stack the two edge lists into a single dataset
## for visualization purposes (MPNet rows first).
combined_edgelist <- dplyr::bind_rows(edgelist_base, edgelist_roberta)


## Here we visualize the distribution
## of the cosine similarity scores by
## 1) embedding
## 2) whether the pair is in the recall set
## Density of cosine similarity, split by recall membership (fill),
## with one panel per embedding model stacked vertically.
combined_edgelist %>%
  ggplot(aes(x = weight, fill = recall)) +
  geom_density(alpha = 0.5) +
  facet_wrap(vars(type), ncol = 1) +
  labs(x = "Cosine Similarity", y = "Density", fill = NULL) +
  theme_bw()
## figures/embedding_candidate_compare_models.pdf
## We include this in the Supplemental Index,
## Figure A8

## Because we observed greater separation
## between recall and non-recall set in MPNet Base
## we use the MPNet base embeddings in our candidate 
## generation and set a threshold of .7
## for our bi-encoder step

## Here we create an annotated plot 
## with just the MPNet embeddings and with the
## threshold drawn at .7
## We include this plot in the SI, 
## Figure A6
## Density of MPNet cosine similarity by recall membership, with the
## bi-encoder cutoff drawn as a dashed vertical line at .7 and a label
## summarising what the cutoff keeps and discards.
ggplot(edgelist_base,
       mapping = aes(x = weight,
                     fill = recall)) +
  geom_density(alpha = .5) +
  theme_bw() +
  labs(x = "Between Article Summary Embedding Cosine Similarity",
       y = "Density",
       title = "Candidate Pairs Step 1: Bi-Encoder Cutoff",
       fill = NULL) +
  ## The cutoff is passed as an argument rather than through aes().
  ## Inside aes() the constant is evaluated against the plot data, so
  ## the same line is drawn once per row of edgelist_base.
  geom_vline(xintercept = .7,
             linetype = 2) +
  ## The counts in this label are hard-coded from the summary table
  ## computed further down in this script.
  annotate(geom = "label",
           x = .25,
           y = 4,
           label = "Setting between embedding\ncosine similarity cutoff at .7\nrecalls 118/121 (97.5%) of recall set,\ndiscards 5.6 million\nor 93.6% of total possible pairs")
## figures/bi_encoder_recall.pdf

## Here we calculate 
## with the MPNet embeddings:
## 1) of all articles in the recall set
## what percent of articles are above the bi-encoder
## threshold of .7
## 2) of articles NOT in the recall set
## what percent of articles are discarded with this
## cutoff
## We present these findings in the main text
## section 4.2
## Per recall group: number of pairs (n), pairs below the cutoff
## (below), pairs at or above it (test_7), the share kept (prop)
## and the share discarded (discard).
summarize(
  group_by(edgelist_base, recall),
  n = n(),
  below = sum(weight < 0.7),
  test_7 = sum(weight >= 0.7),
  prop = test_7 / n,
  discard = 1 - prop
)
## recall 118 / 121
118 / 121
## 97.5
## discard  5,699,472 (93.6%)
5699472 / nrow(edgelist_base)

## Keep only the MPNet pairs at or above the .7 threshold.
## These pairs go on to the second step of candidate
## generation (cross encoder).
edgelist_run <- dplyr::filter(edgelist_base, weight >= 0.7)

nrow(edgelist_run) ## 392,320


#saveRDS(edgelist_run,
#        "data/potential_matches_bioweapons_cosine_sbert.rds")

